import argparse
import asyncio
import random
from collections import defaultdict
from functools import partial
from typing import Optional

import orjsonl
from chainlite import (
    chain,
    get_total_cost,
    llm_generation_chain,
    write_prompt_logs_to_file,
)
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from tqdm import tqdm

from pydantic import BaseModel, Field
from schema_definition import (
    acled_subevent_class_to_event_class_map,
)
from utils import (
    extract_classes_info,
    extract_tag_from_llm_output,
    format_python_code,
    run_async_in_parallel,
    string_to_event_object,
    string_to_list,
)


# Schema for one entity/evidence pair in the structured LLM output.
# Every field is declared with `...` (required); the Optional ones must be
# present but may be null. Within this file only `entity_original` and `span`
# are read back (by `process_entity_output`).
class EntitySpan(BaseModel):
    # Name of the candidate entity this record is about.
    entity_original: str = Field(
        ...,
        description="The original entity name, copied from the list of given entities in the input",
    )
    # Evidence text for the entity; a null/empty value is treated as "no evidence".
    span: Optional[str] = Field(
        ...,
        description="The exact text from the article that mentions the entity",
    )

    # Character offset where `span` starts (not used elsewhere in this file).
    start_index: Optional[int] = Field(
        ...,
        description="The character index where the entity span starts in the article",
    )

    # Character offset where `span` ends (not used elsewhere in this file).
    end_index: Optional[int] = Field(
        ...,
        description="The character index where the entity span ends in the article",
    )

    # Free-text justification for the match (not used elsewhere in this file).
    explanation: Optional[str] = Field(
        ...,
        description="Explanation of why this span matches the given entity",
    )


# Top-level structured output of the entity extraction chain: it is passed as
# `pydantic_class` when `actor_extraction_chain` is built and is then consumed
# by `process_entity_output`.
class AllEntitySpans(BaseModel):
    # One record per entity/evidence pair.
    spans: list[EntitySpan]

@chain
def process_entity_output(all_spans: AllEntitySpans) -> dict[str, list[str]]:
    """Group the extracted evidence spans by the entity they belong to.

    Args:
        all_spans: Structured output of the entity extraction prompt. A falsy
            value, or an object whose ``spans`` attribute is missing or
            ``None``, is treated as "nothing extracted".

    Returns:
        A dict mapping each ``entity_original`` to the list of its non-empty
        span texts. An entity whose spans are all empty or ``None`` is still
        present, mapped to an empty list.
    """
    spans_by_entity: dict[str, list[str]] = {}
    if not all_spans or getattr(all_spans, "spans", None) is None:
        return spans_by_entity

    for entity_span in all_spans.spans:
        # Register the entity even without evidence so callers can tell
        # "mentioned without a span" apart from "never returned".
        evidence = spans_by_entity.setdefault(entity_span.entity_original, [])
        if entity_span.span:
            evidence.append(entity_span.span)

    return spans_by_entity

# Shared rich console; `main` uses it to print the per-article report panels.
console = Console()

from log_utils import get_logger

# Module-level logger obtained from `log_utils.get_logger`.
logger = get_logger(__name__)


def preds_to_string(p1: str, p2: str) -> str:
    """Build an order-independent key from two prediction labels.

    The two labels are sorted before being joined with a comma, so
    ``(a, b)`` and ``(b, a)`` produce the same string.
    """
    ordered_pair = sorted((p1, p2))
    return ",".join(ordered_pair)


@chain
def convert_to_list(llm_output: str):
    """Parse a ``>``-separated ranking emitted by the LLM into a list of labels.

    A single trailing period is dropped, every square bracket is removed, and
    the remaining text is split on ``>``; each piece is whitespace-stripped.
    """
    text = llm_output[:-1] if llm_output.endswith(".") else llm_output
    without_brackets = text.replace("[", "").replace("]", "")
    return [piece.strip() for piece in without_brackets.split(">")]


# LLM chains shared by the functions in this module. They are placeholders
# here and are assigned in the `if __name__ == "__main__"` section, once the
# command-line arguments (in particular --engine) have been parsed.
event_detection_chain = None
event_actor_detector_chain = None
actor_extraction_chain = None


def calculate_top_k(predictions, gold_outputs, max_k=10):
    """Print Recall@k for ranked predictions, for k = 1 .. ``max_k``.

    An example counts as correct at rank k when its gold label appears
    anywhere in the first k predicted labels.

    Args:
        predictions: One ranked list of labels per example.
        gold_outputs: The gold label of each example, aligned with
            ``predictions``.
        max_k: Largest k to report. It is capped at the length of the longest
            prediction list.

    Returns:
        None. One ``Recall@k: <percent>`` line per k is printed. Nothing is
        printed when either input is empty.
    """
    # Without examples there is nothing to rank (and max() / the division
    # below would raise), so report nothing.
    if not predictions or not gold_outputs:
        return

    longest_ranking = max(len(pred) for pred in predictions)
    max_k = min(max_k, longest_ranking)

    hits_at_k = [0] * max_k
    for pred, gold in zip(predictions, gold_outputs):
        found = False
        for k in range(max_k):
            # Once the gold label has been seen at some rank, every larger k
            # counts as a hit as well.
            if not found and k < len(pred) and pred[k] == gold:
                found = True
            if found:
                hits_at_k[k] += 1

    total = len(gold_outputs)
    for k, hits in enumerate(hits_at_k, start=1):
        print(f"Recall@{k}: {hits / total * 100:.1f}")


def load_input_file(args):
    """Load the input examples, optionally filtered and subsampled by language.

    Args:
        args: Parsed command-line arguments. Reads ``input_file`` (path to a
            jsonl file), ``language`` (keep only this language when set) and
            ``per_language_size`` (when positive, the maximum number of
            randomly sampled examples kept per language).

    Returns:
        A list of dicts with the keys ``id``, ``article``,
        ``gold_event_object`` and ``language``.
    """
    data = orjsonl.load(args.input_file)
    sample_size = args.per_language_size

    if args.language:
        data = [item for item in data if item["language"] == args.language]
        if sample_size > 0:
            data = random.sample(data, min(sample_size, len(data)))
    elif sample_size > 0:
        # Sample each language separately so every language keeps at most
        # `sample_size` examples.
        data_by_language = defaultdict(list)
        for item in data:
            data_by_language[item["language"]].append(item)
        # The sampled items must replace `data`; previously they were
        # collected and then discarded, so the option had no effect here.
        data = [
            item
            for items in data_by_language.values()
            for item in random.sample(items, min(sample_size, len(items)))
        ]

    return [
        {
            "id": item["acled_id"],
            "article": item["plain_text"],
            "gold_event_object": string_to_event_object(item["gold_label"]),
            "language": item["language"],
        }
        for item in data
    ]


def load_entity_file(
    entity_description_file: str = "dataset/entity_descriptions.jsonl",
):
    """Load the entity database as a ``{entity_name: description}`` dict.

    Each record of the jsonl file is expected to carry an ``entity_name`` and
    a ``description`` key. If a name occurs more than once, the last
    description wins.
    """
    records = orjsonl.load(entity_description_file)
    return {record["entity_name"]: record["description"] for record in records}


def load_event_definitions():
    """Collect the docstring of every event type that can be predicted.

    Returns:
        A ``(definitions, event_types_info)`` tuple. ``event_types_info`` is
        the unfiltered result of ``extract_classes_info(include_abstract=True)``.
        ``definitions`` maps an event type name to its ``"docstring"`` entry,
        leaving out ``"ACLEDEvent"`` and every name that appears as a value of
        ``acled_subevent_class_to_event_class_map``.
    """
    event_types_info = extract_classes_info(include_abstract=True)
    excluded_names = acled_subevent_class_to_event_class_map.values()

    all_event_definitions = {
        name: info["docstring"]
        for name, info in event_types_info.items()
        if name != "ACLEDEvent" and name not in excluded_names
    }

    return all_event_definitions, event_types_info


async def predict_event_type(
    articles: list[str],
    event_definitions: list[dict],
) -> list[list[str]]:
    """Rank the candidate event types of each article with the detection chain.

    Args:
        articles: Article texts.
        event_definitions: For each article, a dict of candidate event types
            (keyed by event type name), aligned with ``articles``.

    Returns:
        One list of predicted labels per article. Articles with more than one
        candidate are sent to ``event_detection_chain`` and the labels it
        returns are restricted to the candidates. Articles with at most one
        candidate get their candidate keys back without an LLM call.
    """
    assert isinstance(event_definitions, list)
    assert len(articles) == len(event_definitions)

    # Only articles with a real choice to make are sent to the LLM.
    chain_inputs = [
        {"article": article, "event_definitions": candidates}
        for article, candidates in zip(articles, event_definitions)
        if len(candidates) > 1
    ]
    predictions = await event_detection_chain.abatch(
        chain_inputs, config={"max_concurrency": 200}
    )

    predicted_labels = []
    cursor = 0  # position of the next unused LLM prediction
    for candidates in event_definitions:
        if len(candidates) > 1:
            # filter the occasional invalid predicted label
            valid_labels = [
                label for label in predictions[cursor] if label in candidates
            ]
            predicted_labels.append(valid_labels)
            cursor += 1
        else:
            predicted_labels.append(list(candidates.keys()))

    assert len(predicted_labels) == len(articles)

    return predicted_labels


async def do_event_argument_extraction(data, all_event_types_info, engine):
    """Extract the event arguments of every row with one LLM chain per event type.

    Args:
        data: Rows carrying an ``article`` and a ``predicted_event_type``.
        all_event_types_info: Maps an event type name to its info dict; the
            ``"class_obj"`` entry is used as the structured output class.
        engine: LLM engine passed to ``llm_generation_chain``.

    Returns:
        A list aligned with ``data`` holding the chain output for each row.
    """
    instruction_block = (
        "instruction",
        """You are an AI assistant tasked with extracting event arguments from a given news article. You will be provided with annotation guidelines for an event type and a news article to analyze.""",
    )
    input_block = (
        "input",
        """{{ article }}. Extract the arguments of the main event in this article, which is of type {{ event_type }}.
                    For "entity" arguments, note that an entity can be a generic term like "Students" or "Protestors", or specific political groups, militia, armed groups, etc. Never use an individual's name as an entity.
                    
                    Sometimes, a specific entity is accompanied with a generic entity. For example if a political party is leading a protest, both the political party's name and the "Protestors" should be included as entities.
                    When identifying an entity, provide as much information as possible.""",
    )

    # One chain per event type, because each type has its own output class.
    eae_chains = {
        event_type: llm_generation_chain(
            template_file="eae",
            template_blocks=[instruction_block, input_block],
            engine=engine,
            max_tokens=1000,
            pydantic_class=type_info["class_obj"],
        )
        for event_type, type_info in all_event_types_info.items()
    }

    # Bucket the rows by event type, remembering each row's position so the
    # results can be written back in the original order.
    outputs = []
    inputs_by_type = defaultdict(list)
    for position, row in enumerate(data):
        outputs.append(None)
        predicted_type = row["predicted_event_type"]
        assert predicted_type in eae_chains
        payload = {"article": row["article"], "event_type": predicted_type}
        inputs_by_type[predicted_type].append((position, payload))

    for event_type, tagged_inputs in tqdm(
        inputs_by_type.items(),
        desc="Extracting event arguments for each event type",
    ):
        positions = [position for position, _ in tagged_inputs]
        payloads = [payload for _, payload in tagged_inputs]
        results = await eae_chains[event_type].abatch(payloads)
        for position, result in zip(positions, results):
            outputs[position] = result

    return outputs



async def do_event_detection(args, all_event_definitions, data):
    """Predict ranked event types for every row and print Recall@k against gold.

    Args:
        args: Parsed command-line arguments (not read by this function).
        all_event_definitions: Candidate event types offered for every article.
        data: Rows carrying an ``article`` and a ``gold_event_object``.

    Returns:
        One ranked list of predicted event type names per row.
    """
    article_texts = [row["article"] for row in data]
    # Every article is ranked against the same full set of definitions.
    shared_definitions = [all_event_definitions for _ in article_texts]

    ranked_event_types: list[list[str]] = await predict_event_type(
        article_texts, shared_definitions
    )

    gold_event_types = [row["gold_event_object"].get_event_type() for row in data]

    logger.info("After first round of ranking event types:")
    calculate_top_k(ranked_event_types, gold_event_types)

    return ranked_event_types


async def bulk_filter_entities(
    entity_names,
    entity_dict,
    event_object,
    article,
    shard_size,
) -> set[str]:
    """Coarse entity filter: ask the LLM which candidate entities fit the article.

    The candidates are split into shards of ``shard_size`` names and every
    shard is sent to ``event_actor_detector_chain`` in one batch.

    Args:
        entity_names: Candidate entity names.
        entity_dict: Optional ``{name: description}`` mapping. When truthy,
            each candidate is presented as ``"name: description"``; otherwise
            only the names are sent.
        event_object: Event providing ``location.country`` and
            ``get_event_type()`` for the prompt.
        article: Article text.
        shard_size: Number of candidates per LLM call.

    Returns:
        The set of entities returned by the LLM that are also in
        ``entity_names``. Returned names outside the candidates are logged
        and dropped.
    """
    entity_names = list(entity_names)
    known_entities = set(entity_names)  # O(1) membership for the filtering below
    entity_shards = [
        entity_names[i : i + shard_size]
        for i in range(0, len(entity_names), shard_size)
    ]

    shard_entity_candidates = await event_actor_detector_chain.abatch(
        [
            {
                "article": article,
                "potential_entities": (
                    [f"{a}: {entity_dict[a]}" for a in shard]
                    if entity_dict
                    else shard
                ),
                "country": event_object.location.country,
                "event": event_object.get_event_type(),
            }
            for shard in entity_shards
        ]
    )

    # Flatten the per-shard answers into one set of detected names.
    detected = {
        entity
        for shard_result in shard_entity_candidates
        for entity in shard_result
    }

    # only keep the entities that are in the database
    not_in_db = detected - known_entities
    if not_in_db:
        logger.info(
            "entities not in the database = %s",
            not_in_db,
        )

    return detected & known_entities


async def extractive_filter_entities(
    entity_names,
    entity_dict,
    event_object,
    article,
    shard_size,
) -> set[str]:
    """Evidence-based entity filter: keep the entities with a supporting span.

    The candidates are split into shards of ``shard_size`` names and every
    shard is sent to ``actor_extraction_chain``, whose output maps an entity
    to the evidence found for it.

    Args:
        entity_names: Candidate entity names.
        entity_dict: Optional ``{name: description}`` mapping. When truthy,
            each candidate is presented as ``"name: description"``; otherwise
            only the names are sent.
        event_object: Event providing ``get_event_type()`` for the prompt.
        article: Article text.
        shard_size: Number of candidates per LLM call.

    Returns:
        The set of entities that have non-empty evidence and are also in
        ``entity_names``. Returned names outside the candidates are logged
        and dropped.
    """
    entity_names = list(entity_names)
    known_entities = set(entity_names)  # O(1) membership for the filtering below
    entity_shards = [
        entity_names[i : i + shard_size]
        for i in range(0, len(entity_names), shard_size)
    ]

    shard_entity_candidates = await actor_extraction_chain.abatch(
        [
            {
                "source_text_cleaned": article,
                "sub_event_type": event_object.get_event_type(),
                "entities": (
                    [f"{a}: {entity_dict[a]}" for a in shard]
                    if entity_dict
                    else shard
                ),
            }
            for shard in entity_shards
        ]
    )

    # find all entities that had evidence in the article
    supported = set()
    for shard_result in shard_entity_candidates:
        assert isinstance(shard_result, dict)
        supported.update(
            entity for entity, evidence in shard_result.items() if evidence
        )

    # only keep the entities that are in the database
    not_in_db = supported - known_entities
    if not_in_db:
        logger.info(
            "entities not in the database = %s",
            not_in_db,
        )

    return supported & known_entities


async def do_entity_linking(
    row: dict,
    entity_dict: dict,
    shard_size=63,  # ~100 shards
):
    """Select the database entities involved in one row's predicted event.

    Filtering happens in two rounds: a coarse pass over every entity name
    (without descriptions), then an evidence-based pass over the survivors
    together with their descriptions.

    Args:
        row: Must provide ``article`` and ``predicted_event``.
        entity_dict: ``{entity_name: description}`` for the whole database.
        shard_size: Number of entity names per LLM call in the first round.

    Returns:
        The entities kept by the second round.
    """
    predicted_event = row["predicted_event"]
    article_text = row["article"]

    # Round 1 with all entities
    coarse_candidates = await bulk_filter_entities(
        entity_names=list(entity_dict),
        entity_dict=None,
        event_object=predicted_event,
        article=article_text,
        shard_size=shard_size,
    )

    # Round 2, with fewer entities and their descriptions
    return await extractive_filter_entities(
        entity_names=coarse_candidates,
        entity_dict=entity_dict,
        event_object=predicted_event,
        article=article_text,
        shard_size=10,
    )


def _percentage(part, whole):
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is zero."""
    return part / whole * 100 if whole else 0.0


def _report_entity_linking(data, entity_linking_outputs):
    """Print a panel for each row with entity mismatches and log overall precision/recall/F1."""
    false_negatives = 0
    false_positives = 0
    total_gold = 0
    total_pred = 0

    for row, predicted_entities in zip(data, entity_linking_outputs):
        gold_entities = row["gold_event_object"].get_entities()
        predicted_set = set(predicted_entities)
        missing_entities = set(gold_entities) - predicted_set
        extra_entities = predicted_set - set(gold_entities)

        false_negatives += len(missing_entities)
        false_positives += len(extra_entities)
        total_gold += len(gold_entities)
        total_pred += len(predicted_entities)

        # Only rows with at least one mismatch are worth displaying.
        if not (missing_entities or extra_entities):
            continue

        panel_content = Text()
        panel_content.append("Article:\n", style="bold cyan")
        panel_content.append(f"{row['article']}\n\n", style="cyan")
        panel_content.append("Prediction:\n", style="bold magenta")
        panel_content.append(
            f"{format_python_code(repr(row['predicted_event']))}\n\n",
            style="magenta",
        )
        panel_content.append("Gold Label:\n", style="bold green")
        panel_content.append(
            f"{format_python_code(repr(row['gold_event_object']))}", style="green"
        )
        panel_content.append("\n\nPredicted Entities:\n", style="bold blue")
        panel_content.append(f"{predicted_entities}", style="blue")
        panel_content.append(
            "\n\nPrediction after actor assignment:\n", style="bold yellow"
        )
        panel_content.append(
            f"{repr(row['predicted_event_with_linked_entities'])}", style="yellow"
        )
        panel_content.append("\n\nMissing Entities:\n", style="bold red")
        panel_content.append(f"{missing_entities}", style="red")
        panel_content.append("\n\nExtra Entities:\n", style="bold red")
        panel_content.append(f"{extra_entities}", style="red")

        console.print(
            Panel(
                panel_content,
                border_style="blue",
                title="Event Argument Extraction",
            )
        )
        console.print("\n")

    # Every ratio is guarded so empty inputs or empty predictions report 0
    # instead of raising ZeroDivisionError.
    recall = _percentage(total_gold - false_negatives, total_gold)
    precision = _percentage(total_pred - false_positives, total_pred)
    if recall + precision:
        f1 = 2 * recall * precision / (recall + precision)
    else:
        f1 = 0.0
    row_count = len(entity_linking_outputs)
    average_output_size = total_pred / row_count if row_count else 0.0

    logger.info(f"Entity Recall: {recall:.1f}%")
    logger.info(f"Entity Precision: {precision:.1f}%")
    logger.info(f"Entity F1: {f1:.1f}%")
    logger.info(f"Average number of predicted entities: {average_output_size:.1f}")


async def main(args):
    """Run the extraction pipeline end to end and save the predictions.

    Stages, each enabled by a flag on ``args``:
      * ``do_ed``: predict the event type (otherwise the gold type is used).
      * ``do_eae``: extract the event arguments (otherwise the gold event is
        used).
      * ``do_entity_linking``: select database entities, assign them to the
        event, and report entity precision/recall/F1.

    One record per input id is written to ``args.output_file`` with the keys
    ``id``, ``input``, ``output``, ``prediction`` and ``language``.

    Args:
        args: Parsed command-line arguments.
    """
    event_definitions, event_types_info = load_event_definitions()
    data = load_input_file(args)

    if args.do_ed:
        predicted_event_types = await do_event_detection(args, event_definitions, data)
        # Keep only the top-ranked event type of each article; a row whose
        # prediction list is empty is dropped.
        data = [
            {**row, "predicted_event_type": event_type}
            for row, event_types in zip(data, predicted_event_types)
            for event_type in event_types[:1]
        ]
    else:
        # No detection requested: use the gold event type of each row. There
        # are no predictions to zip with in this branch.
        for row in data:
            row["predicted_event_type"] = row["gold_event_object"].get_event_type()

    if args.do_eae:
        eae_outputs = await do_event_argument_extraction(
            data,
            event_types_info,
            engine=args.engine,
        )
        for row, eae_output in zip(data, eae_outputs):
            row["predicted_event"] = eae_output
    else:
        for row in data:
            row["predicted_event"] = row["gold_event_object"]

    if args.do_entity_linking:
        # One assignment chain per event type, because each type has its own
        # output class.
        actor_assignment_chains = {
            event_type: llm_generation_chain(
                "assign_entities.prompt",
                engine=args.engine,
                max_tokens=1000,
                pydantic_class=type_info["class_obj"],
            )
            for event_type, type_info in event_types_info.items()
        }

        entity_dict = load_entity_file()
        entity_linking_outputs = await run_async_in_parallel(
            partial(
                do_entity_linking,
                entity_dict=entity_dict,
            ),
            data,
            max_concurrency=8,
            desc="Entity Linking",
        )

        # Bucket the rows by event type so each bucket goes through the chain
        # built for that type.
        chain_inputs = defaultdict(list)
        for row, predicted_entities in zip(data, entity_linking_outputs):
            event_without_entities, entity_fields = row[
                "predicted_event"
            ].to_string_without_entities()
            chain_inputs[row["predicted_event_type"]].append(
                (
                    row,
                    {
                        "article": row["article"],
                        "entities": [
                            f'"{a}": {entity_dict[a]}' for a in predicted_entities
                        ],
                        "event": event_without_entities,
                        "entity_fields": entity_fields,
                    },
                )
            )
        for event_type in tqdm(chain_inputs, desc="Assigning entities"):
            rows, payloads = zip(*chain_inputs[event_type])
            events_with_entities = await actor_assignment_chains[event_type].abatch(
                payloads
            )
            for row, event in zip(rows, events_with_entities):
                row["predicted_event_with_linked_entities"] = event

        _report_entity_linking(data, entity_linking_outputs)
    else:
        for row in data:
            row["predicted_event_with_linked_entities"] = row["predicted_event"]

    # Rows can share an id; the first row of each id supplies the output record.
    first_row_by_id = {}
    for row in data:
        first_row_by_id.setdefault(row["id"], row)

    outputs = [
        {
            "id": row["id"],
            "input": row["article"],
            "output": repr(row["gold_event_object"]),
            "prediction": repr(row["predicted_event_with_linked_entities"]),
            "language": row["language"],
        }
        for row in first_row_by_id.values()
    ]

    orjsonl.save(args.output_file, outputs)


if __name__ == "__main__":
    # Command-line entry point: parse the options, build the shared LLM
    # chains for the selected engine, run the pipeline, then write the prompt
    # logs and print the total LLM cost.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file", type=str, required=True, help="Path to the input jsonl file"
    )
    parser.add_argument(
        "--output_file", type=str, required=True, help="Path to the output jsonl file"
    )
    parser.add_argument(
        "--engine",
        type=str,
        default="gpt-4o",
        help="The LLM to use",
    )
    parser.add_argument(
        "--do_ed",
        action="store_true",
        help="Whether we should do event type detection. If False, will use gold event types.",
    )
    parser.add_argument(
        "--do_eae",
        action="store_true",
        help="Whether we should do event argument extraction after event type detection",
    )
    parser.add_argument(
        "--do_entity_linking",
        action="store_true",
        help="Whether we should do entity linking.",
    )
    parser.add_argument(
        "--language", default=None, type=str, help="Language code to evaluate on"
    )
    parser.add_argument(
        "--per_language_size",
        type=int,
        default=-1,
        # The sampling code uses random.sample under the fixed seed set below,
        # so the help text describes random sampling.
        help="Number of input samples per language to process. Use -1 to process all samples. Samples are drawn at random with a fixed seed.",
    )

    args = parser.parse_args()

    # Fixed seed so the subsampling is reproducible across runs.
    random.seed(42)

    # The chains below are assigned to the module-level names that the
    # pipeline functions read.
    event_detection_chain = (
        llm_generation_chain(
            "event_detection.prompt",
            engine=args.engine,
            max_tokens=1000,
            progress_bar_desc="Annotating",
        )
        | extract_tag_from_llm_output.bind(tag="answer")
        | convert_to_list
    )

    event_actor_detector_chain = (
        llm_generation_chain(
            "detect_entities.prompt",
            engine=args.engine,
            max_tokens=1000,
        )
        | extract_tag_from_llm_output.bind(tag="entities")
        | string_to_list
    )

    actor_extraction_chain = (
        llm_generation_chain(
            "extract_entities.prompt",
            engine=args.engine,
            max_tokens=2000,
            pydantic_class=AllEntitySpans,
        )
        | process_entity_output
    )

    asyncio.run(main(args))

    write_prompt_logs_to_file()
    print(f"Total LLM cost: ${get_total_cost():.1f}")
